"""
Time                2023/09/19 15:41
Author:             mingfeng (SunnyQjm)
Email               mfeng@linux.alibaba.com
File                schemas.py
Description:        CRUD query helpers (paging, filtering, sorting) and standard API response envelopes.
"""
from functools import lru_cache
from abc import abstractmethod
from typing import Any, Optional, Generic, Type, TypeVar, List, Union
from clogger import logger
from pydantic import BaseModel, create_model
from starlette.requests import Request
from sqlalchemy import desc, asc, and_
from sqlalchemy.orm import Session, Query
from sqlalchemy.sql import func
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql.sqltypes import Integer, String, Enum, SmallInteger


#######################################################################################
# CRUD Helper
#######################################################################################


class BaseQueryParamsInterface:
    """Contract for objects that feed query parameters into ``QueryBuilder``.

    ``QueryBuilder`` treats a ``None`` return from any of these hooks as
    "do not apply this step". Note that this class does not derive from
    ``abc.ABC``, so ``@abstractmethod`` is informational only: it can be
    instantiated and every hook then returns ``None``.
    """

    @abstractmethod
    def get_skip(self) -> Optional[int]:
        """Return the number of rows to skip, or ``None`` for no OFFSET."""

    @abstractmethod
    def get_limit(self) -> Optional[int]:
        """Return the maximum number of rows, or ``None`` for no LIMIT."""

    @abstractmethod
    def get_model_class(self) -> object:
        """Return the model class the parameters refer to."""

    @abstractmethod
    def get_filters(self) -> Any:
        """Return a WHERE criterion, or ``None`` for no filtering."""

    @abstractmethod
    def get_sorter(self) -> Any:
        """Return an ORDER BY clause, or ``None`` for no sorting."""


class QueryBuilder:
    """Chainable helper that applies ``query_params`` to a SQLAlchemy query.

    Every ``apply_*`` step asks ``query_params`` for one value, narrows
    ``self.queryset`` with the matching ``Query`` method only when that value
    is not ``None``, and returns ``self`` so steps can be chained; ``build()``
    hands back the resulting query.
    """

    def __init__(self, queryset: Query, query_params: BaseQueryParamsInterface) -> None:
        self.queryset: Query = queryset
        self.query_params = query_params

    def _narrow(self, value: Any, method_name: str) -> "QueryBuilder":
        """Replace ``queryset`` with ``queryset.<method_name>(value)`` unless ``value`` is None."""
        if value is None:
            return self
        apply = getattr(self.queryset, method_name)
        self.queryset = apply(value)
        return self

    def apply_filter(self) -> "QueryBuilder":
        """Add the WHERE criterion returned by ``get_filters()``."""
        return self._narrow(self.query_params.get_filters(), "filter")

    def apply_sorter(self) -> "QueryBuilder":
        """Add the ORDER BY clause returned by ``get_sorter()``."""
        return self._narrow(self.query_params.get_sorter(), "order_by")

    def apply_offset(self) -> "QueryBuilder":
        """Add the OFFSET returned by ``get_skip()``."""
        return self._narrow(self.query_params.get_skip(), "offset")

    def apply_limit(self) -> "QueryBuilder":
        """Add the LIMIT returned by ``get_limit()``."""
        return self._narrow(self.query_params.get_limit(), "limit")

    def apply_paging(self) -> "QueryBuilder":
        """Apply the offset step followed by the limit step."""
        return self.apply_offset().apply_limit()

    def build(self) -> Query:
        """Return the query assembled so far."""
        return self.queryset


class BaseQueryParams(BaseModel):
    """Paging, sorting and column-filter parameters for list queries.

    Subclasses set ``__modelclass__`` to the SQLAlchemy model being queried and
    declare extra fields named after model attributes; every extra field with
    a non-``None`` value becomes a filter condition (see ``get_filters``).
    """

    # 1-based page index.
    current: int = 1
    # Rows per page; also used as the LIMIT.
    pageSize: int = 10
    # Attribute to order by; a leading "-" means descending, blank disables sorting.
    sort__: str = "-created_at"

    # Model class the parameters apply to; subclasses must set it (see get_model_class).
    __modelclass__: Optional[object] = None

    def get_skip(self) -> Optional[int]:
        """Return the row offset of the current page.

        The result is clamped at 0 so that ``current < 1`` (or a negative
        ``pageSize``) never produces a negative OFFSET, which databases such
        as PostgreSQL reject.
        """
        return max((self.current - 1) * self.pageSize, 0)

    def get_limit(self) -> Optional[int]:
        """Return the page size used as the LIMIT."""
        return self.pageSize

    def get_model_class(self) -> object:
        """Return ``__modelclass__``.

        Raises:
            Exception: if the subclass did not define ``__modelclass__``.
        """
        if self.__modelclass__ is None:
            raise Exception(f"{self.__class__} missing define __modelclass__")
        return self.__modelclass__

    def _get_column(self, name: str) -> InstrumentedAttribute:
        """Return the mapped attribute ``name`` of the model class.

        ``getattr`` is used (rather than the class ``__dict__``) so attributes
        inherited from a parent mapped class are found. Anything that is not an
        ``InstrumentedAttribute`` (methods, dunders, plain class attributes) is
        rejected, because ``name`` may come straight from the request (``sort__``).

        Raises:
            KeyError: if ``name`` is not a mapped attribute of the model.
        """
        attr = getattr(self.get_model_class(), name, None)
        if not isinstance(attr, InstrumentedAttribute):
            raise KeyError(name)
        return attr

    @staticmethod
    def _coerce_values(column: InstrumentedAttribute, values: List[str]) -> List[Any]:
        """Convert raw string filter items to the column's Python type.

        Enum is checked before Integer/String because SQLAlchemy's ``Enum``
        type is a ``String`` subclass; ``isinstance`` also covers subclasses
        such as ``SmallInteger``/``BigInteger``. Other types are returned as-is.
        """
        column_type = column.type
        if isinstance(column_type, Enum):
            lookup = column_type.__dict__["_object_lookup"]
            return [lookup[value] for value in values]
        if isinstance(column_type, Integer):
            return [int(value) for value in values]
        return values

    def get_filters(self) -> Any:
        """Build the WHERE criterion from the model-specific query fields.

        Every field other than ``current``, ``pageSize`` and ``sort__`` whose
        value is not ``None`` is matched against the model attribute of the
        same name:

        * string values are split on ",", each item is stripped and blank
          items are dropped; a value with no non-blank items is ignored;
        * items are converted with ``_coerce_values`` (int for Integer
          columns, the enum member for Enum columns);
        * one remaining item gives ``column == item``, several give
          ``column.in_(items)`` (e.g. ``alert_item=test5,test6``);
        * non-string values (fields declared with another type in a
          subclass) are compared with ``==`` unchanged.

        Returns:
            An ``and_`` of all conditions, or ``None`` when there are none.

        Raises:
            KeyError: if a field has no mapped attribute on the model, or an
                item is not a recognised value of an Enum column.
            ValueError: if an item for an Integer column is not an integer.
        """
        filter_params = self.__dict__.copy()
        for paging_key in ("current", "pageSize", "sort__"):
            filter_params.pop(paging_key, None)

        filters = []
        for k, v in filter_params.items():
            if v is None:
                continue
            if isinstance(v, str):
                items = [item.strip() for item in v.split(",")]
                values = [item for item in items if item]
                # Ignore empty parameters, e.g. "?status=" or "?status=,".
                if not values:
                    continue
                column = self._get_column(k)
                values = self._coerce_values(column, values)
            else:
                column = self._get_column(k)
                values = [v]

            if len(values) > 1:
                # Multiple values, e.g.: alert_item=test5,test6
                filters.append(column.in_(values))
            else:
                # Single value, e.g.: status=RESOLVED
                filters.append(column == values[0])

        return and_(*filters) if filters else None

    def get_sorter(self) -> Any:
        """Return an ORDER BY clause built from ``sort__``.

        ``"-name"`` sorts descending by ``name``, ``"name"`` ascending;
        surrounding whitespace is ignored and a blank value means no sorting.

        Returns:
            A ``desc``/``asc`` clause, or ``None`` when ``sort__`` is blank.

        Raises:
            KeyError: if the named attribute is not mapped on the model.
        """
        sort_spec = self.sort__.strip()
        if not sort_spec:
            return None
        if sort_spec.startswith("-"):
            return desc(self._get_column(sort_spec[1:]))
        return asc(self._get_column(sort_spec))

    def get_query_builder(self, db: Session) -> QueryBuilder:
        """Return a ``QueryBuilder`` over ``db.query(model_class)`` driven by these params."""
        return QueryBuilder(db.query(self.get_model_class()), self)

    def get_count_by(self, db: Session, attr: InstrumentedAttribute) -> int:
        """Get the total number of entries matching the filters.

        Paging and sorting are not applied, so the result is the total across
        all pages.

        Args:
            db (Session): session used to run the query.
            attr (InstrumentedAttribute): model attribute passed to ``COUNT()``.

        Returns:
            int: total number of entries
        Reference:
            https://blog.csdn.net/chenhepg/article/details/105169255
        """
        exp = db.query(func.count(attr))
        filters = self.get_filters()
        if filters is not None:
            exp = exp.filter(filters)
        return exp.scalar()

    def get_query_exp(self, db: Session) -> Any:
        """Return the filtered, sorted and paged query for the model class."""
        return (
            self.get_query_builder(db)
            .apply_filter()
            .apply_sorter()
            .apply_paging()
            .build()
        )


#######################################################################################
# Response Helper
#######################################################################################

# Type parameters shared by the response helpers below:
#   M - ORM model type, S - pydantic schema type.
M, S = TypeVar("M", bound=object), TypeVar("S", bound=BaseModel)


@lru_cache(maxsize=None)
def get_standard_response_model(cls: Type[BaseModel]) -> Type[BaseModel]:
    """Return the standard single-object response model for ``cls``.

    The generated pydantic model is named ``StandardData[<cls name>]`` and
    has the shape::

        {
            "code": 200,
            "message": "",
            "data": {
                ...
            }
        }

    The cache is unbounded on purpose: the key space is the set of schema
    classes in the application, and regenerating an evicted entry would
    create a second, distinct model class with the same name.

    Args:
        cls (Type[BaseModel]): pydantic schema used for the ``data`` field.

    Returns:
        Type[BaseModel]: model with required ``code`` (int) and ``message``
        (str) fields and an optional ``data`` field of type ``cls`` whose
        default is ``{}``. Repeated calls with the same ``cls`` return the
        same class.

    Raises:
        TypeError: if ``cls`` is not a subclass of ``pydantic.BaseModel``.

    Reference:
        https://gist.github.com/wshayes/8e2341bb245a4125b294f6bd5da2df2d
    """
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if not (isinstance(cls, type) and issubclass(cls, BaseModel)):
        raise TypeError(f"{cls!r} is not a pydantic BaseModel subclass")
    return create_model(
        f"StandardData[{cls.__name__}]",
        code=(int, ...),
        message=(str, ...),
        data=(Optional[cls], {}),
    )


@lru_cache(maxsize=None)
def get_standard_list_response_model(cls: Type[BaseModel]) -> Type[BaseModel]:
    """Return the standard list response model for ``cls``.

    The generated pydantic model is named ``StandardListData[<cls name>]``
    and has the shape::

        {
            "code": 200,
            "message": "",
            "data": [
                ...
            ],
            "total": 12
        }

    The cache is unbounded on purpose: the key space is the set of schema
    classes in the application, and regenerating an evicted entry would
    create a second, distinct model class with the same name.

    Args:
        cls (Type[BaseModel]): pydantic schema of each item in ``data``.

    Returns:
        Type[BaseModel]: model with required ``code`` (int) and ``message``
        (str) fields, ``data`` as ``List[cls]`` (default ``[]``) and
        ``total`` (int, default 0). Repeated calls with the same ``cls``
        return the same class.

    Raises:
        TypeError: if ``cls`` is not a subclass of ``pydantic.BaseModel``.
    """
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if not (isinstance(cls, type) and issubclass(cls, BaseModel)):
        raise TypeError(f"{cls!r} is not a pydantic BaseModel subclass")
    return create_model(
        f"StandardListData[{cls.__name__}]",
        code=(int, ...),
        message=(str, ...),
        data=(List[cls], []),
        total=(int, 0),
    )


# Envelope used by StandardResponse when ``data`` is None (e.g. ``error()``).
# NOTE(review): unlike the models from get_standard_response_model, this one
# also carries a ``total`` field — confirm whether clients rely on it before
# removing it.
EmptyStandardData = create_model(
    "StandardData[None]",
    code=(int, ...),
    message=(str, ...),
    data=(Optional[object], {}),
    total=(int, 0),
)

# Envelope used by StandardListResponse when ``data`` is None or empty.
EmptyStandardListData = create_model(
    "StandardListData[None]",
    code=(int, ...),
    message=(str, ...),
    data=(List[object], []),
    total=(int, 0),
)


class StandardResponse(Generic[M, S]):
    """Factory for single-object responses in the standard envelope.

    ``StandardResponse(data, ...)`` does not return a ``StandardResponse``
    instance: ``__new__`` builds and returns an instance of the pydantic model
    produced by ``get_standard_response_model`` (or ``EmptyStandardData`` when
    ``data`` is None). ``StandardResponse[Schema]`` likewise evaluates to that
    pydantic model class (presumably for use as a FastAPI ``response_model``).

    ``bind_base_class`` must be called first so ORM objects can be recognised.
    """

    # ORM base class; instances of it are converted through ``schema_class``.
    __base_class: Optional[Type] = None

    @classmethod
    def bind_base_class(cls, base_class: object):
        """Register the ORM base class used to recognise model instances."""
        cls.__base_class = base_class

    @classmethod
    def __get_base_class(cls):
        """Return the bound ORM base class.

        Raises:
            Exception: if ``bind_base_class`` has not been called.
        """
        if cls.__base_class is None:
            raise Exception(
                "StandardResponse not initial, please call StandardResponse.bind_base_class first"
            )
        return cls.__base_class

    def __class_getitem__(cls, item):
        """Make ``StandardResponse[Schema]`` return the response model for ``Schema``."""
        return get_standard_response_model(item)

    def __new__(
        cls,
        data: Optional[Union[S, M]],
        schema_class: Optional[Type[S]] = None,
        code: int = 200,
        message: str = "",
        request: Optional[Request] = None,
    ) -> "StandardResponse[M, S]":
        """Wrap ``data`` in the standard envelope.

        Args:
            data: a pydantic object (used as-is), an instance of the bound ORM
                base class (converted with ``schema_class.from_orm``), or None
                (gives ``EmptyStandardData`` with ``data={}``).
            schema_class: pydantic schema; required when ``data`` is an ORM object.
            code: value of the ``code`` field.
            message: value of the ``message`` field.
            request: accepted but currently unused.

        Returns:
            An instance of the matching pydantic envelope model.

        Raises:
            Exception: if ``data`` is an ORM object and ``schema_class`` is
                missing, if ``data`` is of an unsupported type, or if
                ``bind_base_class`` was never called.
        """
        response_data: Any
        if data is None:
            response_type = EmptyStandardData
            response_data = {}
        elif isinstance(data, BaseModel):
            response_type = get_standard_response_model(type(data))
            response_data = data
        elif isinstance(data, cls.__get_base_class()):
            if schema_class is None:
                raise Exception("data is orm model object, required pass schema_class")
            response_type = get_standard_response_model(schema_class)
            response_data = schema_class.from_orm(data)
        else:
            # Report the offending type in the exception instead of printing the
            # object to stdout (which bypasses logging and may leak its content).
            raise Exception(
                "data must be orm model object or pydantic object, "
                f"got {type(data).__name__}"
            )
        return response_type(code=code, message=message, data=response_data)

    @classmethod
    def success(
        cls, data: Union[S, M], schema_class: Optional[Type[S]] = None, code: int = 200
    ) -> "StandardResponse[M, S]":
        """Build a response carrying ``data`` with an empty message."""
        return cls.__new__(cls, data, schema_class=schema_class, code=code)

    @classmethod
    def error(cls, message: str, code: int = 500) -> "StandardResponse[M, S]":
        """Build an ``EmptyStandardData`` response carrying ``message`` and ``code``."""
        return cls.__new__(cls, data=None, message=message, code=code)


class StandardListResponse(Generic[M, S]):
    """Factory for list responses in the standard envelope.

    Like ``StandardResponse``, ``StandardListResponse(data, ...)`` returns an
    instance of a pydantic model (from ``get_standard_list_response_model`` or
    ``EmptyStandardListData``), and ``StandardListResponse[Schema]`` evaluates
    to that model class.

    ``bind_base_class`` must be called first so ORM objects can be recognised.
    """

    # ORM base class; items that are instances of it go through ``schema_class``.
    __base_class: Optional[Type] = None

    @classmethod
    def bind_base_class(cls, base_class: object):
        """Register the ORM base class used to recognise model instances."""
        cls.__base_class = base_class

    @classmethod
    def __get_base_class(cls):
        """Return the bound ORM base class.

        Raises:
            Exception: if ``bind_base_class`` has not been called.
        """
        if cls.__base_class is None:
            raise Exception(
                "StandardListResponse not initial, please call StandardResponse.bind_base_class first"
            )
        return cls.__base_class

    def __class_getitem__(cls, item):
        """Make ``StandardListResponse[Schema]`` return the list model for ``Schema``."""
        return get_standard_list_response_model(item)

    def __new__(
        cls,
        data: Optional[List[M]],
        schema_class: Optional[Type[S]] = None,
        code: int = 200,
        message: str = "",
        total: int = 0,
        request: Optional[Request] = None,
    ) -> "StandardListResponse[M, S]":
        """Wrap the items of ``data`` in the standard list envelope.

        Args:
            data: list (or any iterable) of pydantic objects or ORM objects;
                None or empty gives ``EmptyStandardListData``. The kind of the
                first item decides how all items are handled.
            schema_class: pydantic schema; required when the items are ORM objects.
            code: value of the ``code`` field.
            message: value of the ``message`` field.
            total: total number of matching entries (e.g. across all pages);
                0 means "use the number of items in ``data``".
            request: accepted but currently unused.

        Returns:
            An instance of the matching pydantic envelope model.

        Raises:
            Exception: if the items are ORM objects and ``schema_class`` is
                missing, if they are of an unsupported type, or if
                ``bind_base_class`` was never called.
        """
        # Materialise once so generators/tuples work and len()/indexing are safe.
        items = [] if data is None else list(data)
        response_data: List[Any]
        if not items:
            response_type = EmptyStandardListData
            response_data = []
        elif isinstance(items[0], BaseModel):
            response_type = get_standard_list_response_model(type(items[0]))
            response_data = items
        elif isinstance(items[0], cls.__get_base_class()):
            if schema_class is None:
                raise Exception("data is orm model object, required pass schema_class")
            response_type = get_standard_list_response_model(schema_class)
            response_data = [schema_class.from_orm(item) for item in items]
        else:
            raise Exception(
                "data must be orm model object or pydantic object, "
                f"got {type(items[0]).__name__}"
            )
        if total == 0:
            # 0 doubles as "not given": fall back to the number of returned items.
            total = len(response_data)
        return response_type(
            code=code, message=message, data=response_data, total=total
        )

    @classmethod
    def success(
        cls,
        data: Optional[List[Union[S, M]]],
        schema_class: Optional[Type[S]] = None,
        code: int = 200,
        total: int = 0,
    ) -> "StandardListResponse[M, S]":
        """Build a list response carrying ``data`` with an empty message.

        Pass ``total`` (e.g. from ``BaseQueryParams.get_count_by``) when
        ``data`` is a single page; the default 0 reports ``len(data)``.
        """
        return cls.__new__(
            cls, data, schema_class=schema_class, code=code, total=total
        )

    @classmethod
    def error(cls, message: str, code: int = 500) -> "StandardListResponse[M, S]":
        """Build an ``EmptyStandardListData`` response carrying ``message`` and ``code``."""
        return cls.__new__(cls, data=None, message=message, code=code)


class FastApiResponseHelper:
    """Single entry point for configuring both response factories."""

    @staticmethod
    def bind_base_class(base_class: object):
        """Register ``base_class`` as the ORM base on both response factories.

        After this call, ``StandardResponse`` and ``StandardListResponse``
        treat instances of ``base_class`` as ORM objects to be converted with
        ``schema_class.from_orm``.
        """
        for factory in (StandardResponse, StandardListResponse):
            factory.bind_base_class(base_class)
